import logging
from copy import deepcopy
from pathlib import Path
from typing import Optional
import os
import requests
import tarfile
import shutil
import numpy as np
import torch
from hydra.utils import call
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Subset, Dataset
from torchvision import transforms as transform_lib
from torchvision.datasets import ImageFolder
# from filelock import FileLock
from src.data.data_utils import split_subsets_train_val, split_dataset_train_val, add_attrs


class CINIC10DataModule(LightningDataModule):
    """Lightning data module serving CINIC-10 split across federated clients.

    The official ``train`` and ``valid`` folders are merged into one
    ``train_val`` folder. That pool is partitioned among ``num_clients``
    clients with ``split_function`` (invoked through ``hydra.utils.call``), and
    each client gets its own train/validation subset. The dataloaders serve the
    subsets of the client selected by ``current_client_idx``; the official
    ``test`` folder is used unchanged for testing.
    """

    name = "CINIC10"

    def __init__(
            self,
            split_function,
            num_classes: int = 10,
            data_dir: str = Path("/tmp"),
            val_split: float = 0.1,
            pub_size: Optional[float] = None,
            num_workers: int = 16,
            normalize: bool = False,
            seed: int = 42,
            batch_size: int = 32,
            num_clients: int = 3,
            fair_val: bool = False,
            *args,
            **kwargs,
    ):
        """
        Args:
            split_function: passed to ``hydra.utils.call`` together with
                ``dataset=`` to produce the list of per-client subsets.
            num_classes: number of target classes (CINIC-10 has 10).
            data_dir: where to save/load the data; files live under
                ``<data_dir>/cinic-10``.
            val_split: fraction of the training images used for validation.
            pub_size: if set, passed as ``val_split`` to
                ``split_dataset_train_val`` to carve a public dataset out of
                the training pool (presumably a fraction — confirm against
                ``src.data.data_utils``).
            num_workers: how many workers to use for loading data.
            normalize: if True, applies image normalization with the CINIC-10
                channel mean/std.
            seed: starting seed for RNG.
            batch_size: desired batch size.
            num_clients: number of federated clients.
            fair_val: if True, one shared validation split is copied to every
                client instead of splitting each client's own data.
        """
        super().__init__(*args, **kwargs)

        self.data_dir = data_dir
        self.val_split = val_split
        self.pub_size = pub_size
        self.num_workers = num_workers
        self.normalize = normalize
        self.seed = seed
        self.batch_size = batch_size
        self.num_clients = num_clients
        self.fair_val = fair_val
        self.split_function = split_function

        # Previously hard-coded to 10, silently ignoring the argument.
        self.num_classes = num_classes
        # Per-channel statistics used by the Normalize transform.
        self.ds_mean = [0.47889522, 0.47227842, 0.43047404]
        self.ds_std = [0.24205776, 0.23828046, 0.25874835]

        # Placeholders (Ellipsis) until setup()/transfer_setup() populate them.
        self.datasets_train: "list[Subset]" = ...
        self.datasets_val: "list[Subset]" = ...
        self.train_dataset: Dataset = ...
        self.val_dataset: Dataset = ...
        self.test_dataset: Dataset = ...
        self.pub_train_dataset: Optional[Dataset] = None
        self.pub_val_dataset: Optional[Dataset] = None

        self.current_client_idx = 0

        self.is_setup = False
        self.indices_based = True

    def prepare_data(self):
        """Download and extract CINIC-10 into ``<data_dir>/cinic-10`` and build ``train_val``.

        Each step is skipped when its output already exists, so repeated calls
        are cheap. The archive is downloaded to a ``.part`` file and renamed
        only once complete, and ``train_val`` is assembled in a staging folder
        that is renamed at the end. As a result, an interrupted run is retried
        on the next call instead of leaving a truncated archive or a
        half-copied folder that would later be mistaken for finished work.
        """
        if self.is_setup:
            return

        url = "https://datashare.is.ed.ac.uk/bitstream/handle/10283/3192/CINIC-10.tar.gz"
        target_directory = f"{self.data_dir}/cinic-10"

        logging.info("Downloading dataset...")
        self._download_and_extract(url, target_directory)

        # The official train and valid splits together form the pool that is
        # later divided among clients.
        destination_dir = f"{target_directory}/train_val"
        if os.path.exists(destination_dir):
            logging.info("Dataset already combined.")
            return
        logging.info("Combining directories...")
        self._combine_directories(
            [f"{target_directory}/train", f"{target_directory}/valid"],
            destination_dir,
        )

    @staticmethod
    def _download_and_extract(url, target_directory, chunk_size=1 << 20, timeout=60):
        """Fetch the archive at ``url`` into ``target_directory`` and unpack it there.

        Extraction is skipped when the ``train``, ``valid`` and ``test``
        folders already exist. The download is skipped when the tarball is
        already present.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        os.makedirs(target_directory, exist_ok=True)

        split_dirs = [os.path.join(target_directory, split) for split in ("train", "valid", "test")]
        if all(os.path.isdir(split_dir) for split_dir in split_dirs):
            logging.info("Dataset already downloaded.")
            return

        dataset_file = os.path.join(target_directory, "CINIC-10.tar.gz")
        if not os.path.exists(dataset_file):
            partial_file = dataset_file + ".part"
            # Stream to disk rather than buffering the whole archive in memory,
            # and fail loudly on HTTP errors instead of saving an error page.
            with requests.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()
                with open(partial_file, "wb") as f:
                    for chunk in response.iter_content(chunk_size=chunk_size):
                        f.write(chunk)
            # Only a fully written file ever appears under the final name.
            os.replace(partial_file, dataset_file)

        with tarfile.open(dataset_file, "r:gz") as tar:
            # The archive comes from the network. When this Python supports
            # extraction filters, use "data" to reject absolute paths,
            # ``..`` members and special files.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(target_directory, filter="data")
            else:
                tar.extractall(target_directory)

        # The downloaded tarball is intentionally kept (as before).

    @staticmethod
    def _combine_directories(source_dirs, destination_dir):
        """Merge the trees in ``source_dirs`` into a new ``destination_dir``.

        Files are copied with metadata (``shutil.copy2``) in order, so a later
        source overwrites a same-named file from an earlier one. The merge is
        built in a ``<destination_dir>.tmp`` staging folder and renamed into
        place only after every copy succeeds.

        Raises:
            FileNotFoundError: if a source directory does not exist.
        """
        staging_dir = f"{destination_dir}.tmp"
        # Discard leftovers from a previously interrupted merge.
        shutil.rmtree(staging_dir, ignore_errors=True)
        for source_dir in source_dirs:
            shutil.copytree(source_dir, staging_dir, dirs_exist_ok=True)
        os.replace(staging_dir, destination_dir)

    @staticmethod
    def _image_folder(root, transform):
        """Load an ``ImageFolder`` at ``root`` with ``targets`` converted to a long tensor."""
        dataset = ImageFolder(root=root, transform=transform)
        dataset.targets = torch.as_tensor(dataset.targets, dtype=torch.long)
        return dataset

    def setup(self, stage: Optional[str] = None):
        """Split the merged train/valid pool into per-client train and validation subsets.

        Two views of ``train_val`` are loaded: one with augmentation (for
        training) and one with the default transforms (passed to the helpers
        as ``val_dataset`` for the validation side). Optionally a public subset
        is carved out first (``pub_size``). The remainder is then split per
        client, either with a shared validation set (``fair_val``) or with a
        per-client validation split. Subsequent calls are no-ops.

        Args:
            stage: unused; accepted for the Lightning ``setup`` signature.
        """
        if self.is_setup:
            logging.info("Setup already been called. Skipping!")
            return

        root = f"{self.data_dir}/cinic-10/train_val"
        self.train_dataset = self._image_folder(root, self.aug_transforms)
        self.val_dataset = self._image_folder(root, self.default_transforms)

        # TODO: add public dataset here
        if self.pub_size:
            self.train_dataset, self.pub_train_dataset = split_dataset_train_val(
                train_dataset=self.train_dataset,
                val_split=self.pub_size,
                seed=self.seed,
            )

            pub_train_datasets, pub_val_datasets = split_subsets_train_val(
                subsets=[self.pub_train_dataset],
                val_precent=self.val_split,
                seed=self.seed,
                val_dataset=self.val_dataset
            )

            self.pub_train_dataset = pub_train_datasets[0]
            self.pub_val_dataset = pub_val_datasets[0]

        if self.fair_val:
            # One global validation split, copied to every client.
            train_subset, val_subset = split_dataset_train_val(
                train_dataset=self.train_dataset,
                val_split=self.val_split,
                seed=self.seed,
                val_dataset=self.val_dataset
            )
            self.datasets_train = call(
                self.split_function, dataset=train_subset)  # , min_size_of_dataset=self.batch_size)
            self.datasets_val = [deepcopy(val_subset) for _ in range(self.num_clients)]
            add_attrs(self.datasets_train, self.datasets_val)
        else:
            # Presumably sized so each client still has at least one full
            # training batch after its validation split (the train loader uses
            # drop_last=True).
            subsets = call(
                self.split_function, dataset=self.train_dataset,
                min_size_of_dataset=self.batch_size + int(np.ceil(self.batch_size * self.val_split))
            )
            # results is # [train1, t2, ..., tn], [val1, v2, ..., vn]
            self.datasets_train, self.datasets_val = split_subsets_train_val(
                subsets, self.val_split, self.seed, val_dataset=self.val_dataset
            )

        self.is_setup = True

    def transfer_setup(self):
        """Load the full train_val pool (augmented and plain views) and the test split.

        Targets are converted to long tensors, consistent with ``setup`` and
        ``test_dataloader``.
        """
        root = f"{self.data_dir}/cinic-10/train_val"
        self.train_dataset = self._image_folder(root, self.aug_transforms)
        self.val_dataset = self._image_folder(root, self.default_transforms)
        self.test_dataset = self._image_folder(
            f"{self.data_dir}/cinic-10/test", self.default_transforms
        )

    def next_client(self):
        """Advance to the next client so the dataloaders serve its subsets.

        Raises:
            AssertionError: if every client has already been visited; the
                current index is left unchanged in that case.
        """
        next_idx = self.current_client_idx + 1
        if next_idx >= self.num_clients:
            # Raised explicitly (not via ``assert``) so the check survives
            # ``python -O``; AssertionError is kept for existing callers.
            raise AssertionError("Client number shouldn't excced seleced number of clients")
        self.current_client_idx = next_idx

    def train_dataloader(self):
        # check this: https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html#multiple-training-dataloaders
        """Shuffled loader over the current client's training subset (incomplete last batch dropped)."""
        loader = DataLoader(
            self.datasets_train[self.current_client_idx],
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=True,
            pin_memory=True,
        )
        return loader

    def val_dataloader(self):
        """Loader over the current client's validation subset, in fixed order."""
        loader = DataLoader(
            self.datasets_val[self.current_client_idx],
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=False,
            pin_memory=True,
        )
        return loader

    def test_dataloader(self):
        """Loader over the official CINIC-10 test split (shared by all clients)."""
        dataset = self._image_folder(
            f"{self.data_dir}/cinic-10/test", self.default_transforms
        )
        loader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=False,
            pin_memory=True,
        )
        return loader

    @property
    def default_transforms(self):
        """Tensor conversion, plus CINIC-10 normalization when ``normalize`` is set."""
        cinic_transforms = [
            transform_lib.ToTensor(),
        ]
        if self.normalize:
            cinic_transforms.append(transform_lib.Normalize(mean=self.ds_mean,
                                                            std=self.ds_std))

        return transform_lib.Compose(cinic_transforms)

    @property
    def aug_transforms(self):
        """Random 32px crop (4px padding) and horizontal flip, then ``default_transforms``."""
        augmentations = [
            transform_lib.RandomCrop(32, padding=4),
            transform_lib.RandomHorizontalFlip(),
        ]
        # Reuse the default pipeline so normalization is defined in one place.
        return transform_lib.Compose(augmentations + list(self.default_transforms.transforms))
